Diving deep: learning rate

Date: June 4, 2020

Here, I'm just trying to test out the expected scale of learning rates for several types of networks, to see what range you should try first by default, and whether there are any interesting observations along the way. Initially, I wanted to make the network control its own learning rate, so I'm interested in digging deeper into this.

Simple functions network

This is a network taken from a previous post. Its mission is to predict a simple mathematical function when given multiple $(x, y)$ pairs. I pretty much just copied the code over and shortened it so it looks cleaner, but the functionality is basically the same.

In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

Here we have our custom dataset class, with several functions available. Also note that while the previous post used 10001 total samples, we are going to use 2501 (just below $1280\cdot2=2560$, so each epoch fits in 2 batches):

In [3]:
class FunctionDataset(Dataset):
    def __init__(self, function: callable, start: float=-5, stop: float=5, samples: int=300):
        self.function = function; self.start = start; self.stop = stop; self.samples = samples
    def __len__(self): return self.samples
    def __getitem__(self, index):
        x = index/self.samples * (self.stop - self.start) + self.start
        return x, self.function(x)
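A quick standalone check of the sampling formula (a re-computation for illustration, not a cell from the post): `x = index/samples * (stop - start) + start` covers $[\text{start}, \text{stop})$ with even spacing, and `stop` itself is never sampled.

```python
# recompute the x values FunctionDataset.__getitem__ would produce,
# using a tiny sample count so the pattern is visible
samples, start, stop = 5, -5.0, 5.0
xs = [i/samples * (stop - start) + start for i in range(samples)]
# -> [-5.0, -3.0, -1.0, 1.0, 3.0]; spacing is (stop-start)/samples, stop excluded
```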
In [112]:
bs = 1280; sampl = 2501; expF = lambda x: torch.exp(x); expNF = lambda x: torch.exp(-x)
logF = lambda x: torch.log(x); invF = lambda x: 1 / x; linF = lambda x: 2 * x + 8
sinF = lambda x: torch.sin(x); stepF = lambda x: x > 0
expDl = DataLoader(FunctionDataset(lambda x: np.exp(x), samples=sampl), batch_size=bs)
expNDl = DataLoader(FunctionDataset(lambda x: np.exp(-x), samples=sampl), batch_size=bs)
exp7Dl = DataLoader(FunctionDataset(lambda x: np.exp(x), samples=sampl, stop=7), batch_size=bs)
logDl = DataLoader(FunctionDataset(lambda x: np.log(x), start=0.01, samples=sampl), batch_size=bs)
invDl = DataLoader(FunctionDataset(invF, samples=sampl), batch_size=bs)
linDl = DataLoader(FunctionDataset(linF, samples=sampl), batch_size=bs)
sinDl = DataLoader(FunctionDataset(lambda x: np.sin(x), samples=sampl), batch_size=bs)
stepDl = DataLoader(FunctionDataset(stepF, samples=sampl), batch_size=bs)

Here we have the network definition. It is broken up into multiple cells so that it's easier to run and navigate around the notebook. Also, I have made a lot of NN's methods return the network itself, so that we can chain calls and save space:

In [7]:
# simple (#batch, 1) to (#batch, 1)
class NN(nn.Module):
    def __init__(self, hiddenDim=10, dropout_p=0, activation=nn.ReLU()):
        super().__init__()
        self.fc_begin = nn.Linear(1, hiddenDim); self.fc1 = nn.Linear(hiddenDim, hiddenDim)
        self.fc2 = nn.Linear(hiddenDim, hiddenDim); self.fc_end = nn.Linear(hiddenDim, 1)
        self.activation = activation; self.totalLosses = []; self.dropout = nn.Dropout(dropout_p)
    def forward(self, x):
        x = self.dropout(self.activation(self.fc_begin(x)))
        x = self.dropout(self.activation(self.fc1(x)))
        x = self.dropout(self.activation(self.fc2(x)))
        return self.fc_end(x)
    def plotLosses(self, begin=0, end=0):
        plt.figure(num=None, figsize=(10, 3), dpi=350)
        if end == 0: end = len(self.totalLosses)
        plt.plot(range(len(self.totalLosses))[begin:end], self.totalLosses[begin:end])
        plt.legend(["Loss"]); plt.show(); return self
In [183]:
def train(self, dl, optimizer=None, lr=0.01, epochs=500, trackLosses=True):
    lossFunction=nn.MSELoss()
    if optimizer == None: optimizer = optim.Adam(self.parameters(), lr=lr)
    for epoch in range(epochs):
        totalLoss = 0
        for x, y in dl:
            optimizer.zero_grad(); output = self.forward(x.view(-1, 1).float().cuda())
            loss = lossFunction(output, y.view(-1, 1).float().cuda())
            loss.backward(); totalLoss += loss.item(); optimizer.step()
        totalLoss /= dl.batch_size
        if trackLosses: self.totalLosses.append(totalLoss)
    return self
NN.train = train  # careful: this shadows nn.Module.train(mode), so never call .train()/.eval() to toggle modes on NN

Hopefully you remember that this function plots the predicted output for a given input, dividing the output into multiple colored segments, because I thought it would give us some intuition about the network exhaustion point. If you don't get this, don't worry, we will see its output soon.

In [119]:
def plot(self, x, function: callable=None):
    plt.figure(num=None, figsize=(10, 3), dpi=350); x = x.view(-1, 1)
    if function != None: plt.plot(x.cpu(), function(x).cpu(), "--")
    x, colors = x.cuda(), ["r", "g", "b", "c", "k"]
    last, segmentNumber, epsilon = 1, 0, 1e-3  # gradient-change threshold for starting a new segment
    y = self(x).detach().view(-1); x = x.view(-1)
    gradients = (y - y.roll(1))/(x - x.roll(1))
    transition = torch.abs(gradients - gradients.roll(1)) > epsilon
    for step in range(2, len(x)):
        if transition[step]:
            color = colors[segmentNumber%len(colors)]
            if segmentNumber%len(colors) == 1:
                plt.plot(x[last-1:step].cpu(), y[last-1:step].cpu(), "-"+color, lw=4)
            plt.plot(x[last-1:step].cpu(), y[last-1:step].cpu(), "-"+color)
            segmentNumber += 1; last = step
    else:
        color = colors[segmentNumber%len(colors)]
        if segmentNumber%len(colors) == 1:
            plt.plot(x[last-1:step].cpu(), y[last-1:step].cpu(), "-"+color, lw=4)
        plt.plot(x[last-1:step].cpu(), y[last-1:step].cpu(), "-"+color)
    plt.legend(["Real", "Learned"] if function != None else ["Learned"]); plt.show();return self
NN.plot = plot

Here are the first new bits. These let us save the parameters at one point in time, train on something, and then recover the initial state:

In [8]:
def importParams(self, params):
    oldParams = list(self.parameters())
    for i in range(len(params)): oldParams[i].data = params[i].data.clone()
NN.importParams = importParams
def exportParams(self):
    params = []
    for param in self.parameters(): params.append(param.data)
    return params
NN.exportParams = exportParams
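For completeness, here's a minimal sketch of the same save-and-restore roundtrip using PyTorch's built-in state_dict mechanism instead (not the post's method; the tiny `nn.Linear` model is just a stand-in):

```python
import copy
import torch
import torch.nn as nn

net = nn.Linear(1, 1)
snapshot = copy.deepcopy(net.state_dict())   # save parameters at one point in time

with torch.no_grad():
    for p in net.parameters():
        p.add_(1.0)                          # "train" on something (mutate the weights)

net.load_state_dict(snapshot)                # recover the initial state
restored = net.state_dict()
same = all(torch.equal(restored[k], snapshot[k]) for k in snapshot)
```

The `deepcopy` matters: `state_dict()` returns references to the live tensors, so without it the snapshot would mutate along with the network.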

This is a function that gets the loss when passed in a DataLoader. Pretty standard.

In [121]:
def getLoss(self, dl:DataLoader):
    lossFunction=nn.MSELoss(); totalLoss = 0
    for x, y in dl:
        output = self.forward(x.view(-1, 1).float().cuda())
        totalLoss += lossFunction(output, y.view(-1, 1).float().cuda()).item()
    return totalLoss / dl.batch_size
NN.getLoss = getLoss

This is a function that tests several learning rates, training for a number of epochs at each learning rate and then measuring the resulting performance. It stops prematurely if the loss explodes too much. Finally, it returns the log-distributed learning rates and their respective losses:

In [222]:
def findLr(self, dl:DataLoader, epochs:int=2, begin:float=1e-6, end:float=100, steps:int=100, full:bool=False):
    lrs = torch.exp(torch.linspace(np.log(begin), np.log(end), steps))
    losses = []; initialLoss = self.getLoss(dl); initialParams = self.exportParams()
    for lr in lrs:
        self.train(dl, lr=lr, epochs=epochs, trackLosses=False); loss = self.getLoss(dl)
        losses.append(loss); self.importParams(initialParams);
        if not full:
            if loss > initialLoss * 1.1: losses.pop(); break
    return lrs[:len(losses)], torch.Tensor(losses)
NN.findLr = findLr

This is a convenience function that takes in a dataset (more precisely a DataLoader), creates a new network, finds the best learning rate using NN.findLr() for 1, 2, 4, and 8 epochs, then graphs the results so they look nice.

In [227]:
def LRFinder(dl:DataLoader, net:NN=None, begin:float=1e-6, end:float=100, full:bool=False):
    net = NN().cuda() if net is None else net
    plt.figure(figsize=(10, 6), dpi=350); initialParams = net.exportParams()
    for i in range(4):
        plt.subplot(2, 2, i+1); plt.title(f"Epoch {2**i}"); plt.grid(True)
        plt.xscale("log"); plt.plot(*net.findLr(dl, int(2**i), begin=begin, end=end, full=full))
        net.importParams(initialParams); plt.xlabel("Learning rate"); plt.ylabel("Loss")
        plt.tight_layout(pad=1.0)

Let's first do a vanilla run, to make sure everything is running:

In [125]:
NN().cuda().train(logDl, epochs=200).plot(torch.linspace(0.01, 7, 300)).plotLosses(); pass

Yep, seems like everything is running just fine. Let's test out our LRFinder on a number of functions:

In [106]:
LRFinder(logDl)

Let's try learning at 2 learning rates. One is at $10^{-3}$, where the network can just barely move its butt over, and the other at $5\cdot10^{-2}$ (remember this is log scale, so $5\cdot10^{-2}$ is closer to $10^{-1}$ than to $10^{-2}$), where the network's loss dips pretty low, but not too close to the point where it blows up:
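A quick sanity check of that log-scale claim (standalone arithmetic, just for illustration):

```python
import math

lr = 5e-2
# distances in decades (log10 units) to the two neighboring powers of ten
d_to_1e1 = abs(math.log10(lr) - math.log10(1e-1))   # ~0.30 decades
d_to_1e2 = abs(math.log10(lr) - math.log10(1e-2))   # ~0.70 decades
closer_to_1e1 = d_to_1e1 < d_to_1e2                 # True: 5e-2 sits nearer 1e-1
```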

In [139]:
NN().cuda().train(logDl,lr=1e-3,epochs=30).plot(torch.linspace(0.01,7,300)).plotLosses(); pass
In [148]:
net=NN().cuda().train(logDl,lr=5e-2,epochs=30).plot(torch.linspace(0.01,7,300)).plotLosses()

At 30 epochs (which is pretty low), the $5\cdot10^{-2}$ learning rate network does pretty well at its job, while the $10^{-3}$ one hardly learns anything at all. If we run it for 200 epochs, will things change?

In [150]:
net=NN().cuda().train(logDl,lr=1e-3,epochs=200).plot(torch.linspace(0.01,7,300)).plotLosses()

This is actually progressing somewhere, but it still doesn't look like a log function. At around 350-500 epochs, it finally looks like it has learned it. So a network with the wrong learning rate can take $425/30\approx14$ times longer than a network with the right one. Let's check out other functions:

In [107]:
LRFinder(expDl)
In [126]:
LRFinder(sinDl)
In [109]:
LRFinder(linDl)
In [110]:
LRFinder(stepDl)

It varies quite a lot between functions, but the general range that gets things going without blowing up is from $10^{-2}$ to $10^{-1}$, so $3.2\cdot10^{-2}$ really seems like a sweet spot ($\sqrt{10}\approx3.2$).
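That sweet spot is just the geometric midpoint of the range, which a one-liner confirms:

```python
import math

# geometric mean of the range endpoints 1e-2 and 1e-1,
# i.e. the midpoint on a log scale: 10**(-1.5) ~= 3.16e-2
mid = math.sqrt(1e-2 * 1e-1)
```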

In [161]:
LRFinder(invDl)

The inverse case, however, doesn't seem to offer any insights at all. Only when we look deep enough, at epoch 8, does there seem to be a dip from $10^{-2}$ to $10^{-1}$. It can still learn though:

In [166]:
net=NN().cuda().train(invDl,lr=3.2e-2,epochs=80).plot(torch.linspace(-0.5,0.5,300)).plotLosses()

What this demonstrates is that for normal, non-extreme functions, we can gauge the right learning rate pretty well by looking ahead only 2-4 epochs, but we shouldn't assume every function is like that.

After training

Let's see whether the learning rate landscape changes after learning for a while:

In [264]:
net=NN().cuda(); LRFinder(logDl, net); lr=3.2e-2; print(f"Learning rate: {lr}")
net.train(logDl,lr=lr,epochs=50).plot(torch.linspace(0.01,7,300), logF).plotLosses()
LRFinder(logDl, net); lr=1e-3; print(f"Learning rate: {lr}")
net.train(logDl,lr=lr,epochs=150).plot(torch.linspace(0.01,7,300), logF).plotLosses()
LRFinder(logDl, net); lr=1e-4; print(f"Learning rate: {lr}")
net.train(logDl,lr=lr,epochs=50).plot(torch.linspace(0.01,7,300), logF).plotLosses()
LRFinder(logDl, net)
Learning rate: 0.032
Learning rate: 0.001
Learning rate: 0.0001

Did you see that? Initially the sweet spot is where we predicted earlier, from $10^{-2}$ to $10^{-1}$. However, once the network has been sufficiently trained, the sweet spot moves down to $10^{-4}$ to $10^{-3}$, over 2 orders of magnitude! It's also pretty interesting how, for the last 2 learning rate scans, $10^{-3}$ generally trains the network better in the long run (8 epochs), despite this seemingly not being the case for 1 or 2 epochs, where $10^{-3}$ looks like it will blow the network up.

What should we expect if we keep using a learning rate of $3.2\cdot10^{-2}$ for a long time? The network should decide that $3.2\cdot10^{-2}$ is too large, and so the loss will blow up. Once it has blown up, the network will realize that it's not doing well, and will adjust the sweet spot to accommodate the learning rate value. It should do this a couple of times, without being able to really decrease the loss. Let's test it out:

In [287]:
net1 = NN().cuda().train(logDl, lr=3.2e-2, epochs=500).plotLosses()

Hopefully you can see those little bumps popping up here and there. Let's try to graph from epoch 50 onwards:

In [288]:
net1.plotLosses(begin=50); pass

Yep, it's spiking every now and then, but doesn't really follow any particular pattern. Let's zoom in on a flat region and see the actual loss:

In [293]:
net1.plotLosses(begin=330, end=415); pass

$2.1\cdot10^{-6}$. Seems like it's doing well. Let's analyze the initial network trained with many different learning rates:

In [296]:
net.plotLosses(begin=75); pass

No more spikes! Remember that the learning rate is $10^{-3}$ from epoch 50 to 200, and at epoch 200 you can see an abrupt slowdown as the learning rate drops to $10^{-4}$. Let's see the flat region afterwards:

In [297]:
net.plotLosses(begin=230); pass

$1.575\cdot10^{-6}$. Also doing well, but that's just $\frac{2.1}{1.575}\approx1.3$ times better than the $3.2\cdot10^{-2}$ learning rate network. So we can conclude that we should indeed lower the learning rate over time, but for stability and to avoid ending up with an erroneous network, rather than to dramatically improve the result.
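The post's train() keeps a fixed learning rate per call and we lowered it by hand between calls. As a sketch of the "lower it over time" conclusion, PyTorch's built-in schedulers can automate the decay (the tiny model and step counts here are made up for illustration):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(1, 1)
optimizer = optim.Adam(model.parameters(), lr=3.2e-2)
# drop the learning rate by 10x every 50 epochs: 3.2e-2 -> 3.2e-3 -> 3.2e-4
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

lrs = []
for epoch in range(150):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 1)).pow(2).mean()  # dummy loss, just to drive steps
    loss.backward(); optimizer.step()
    scheduler.step()                               # decay happens here, once per epoch
    lrs.append(optimizer.param_groups[0]["lr"])
```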

An interesting question is: why can the $3.2\cdot10^{-2}$ learning rate's performance get anywhere near the $10^{-3}$ learning rate's? It seems like the $3.2\cdot10^{-2}$ network should jump around so much that it can never settle. There are actually 2 parts to this. First is the fact that linear layers can generate fantastically high dimensional terrain. This high dimensionality sort of turns every candidate local minimum region into a saddle region. In each dimension, a pocket of space can either be a local minimum or a local maximum. Say the probability of either is $0.5$. Then, in an $n$-dimensional space, the probability that a pocket of space is a local minimum in every dimension is $0.5^n$ (aka very, very low):

In [311]:
x = torch.arange(0, 20, step=1.); plt.plot(x, 0.5**x); pass

Eventually there will be more dimensions curving up than down, and the probability shifts toward the network actually landing in a true local minimum (a minimum in all dimensions). However, that local minimum will actually be quite close to the global minimum, because the network can sort of "tunnel" through the landscape, again thanks to its high dimensionality, and arrive at a better local minimum whenever the global minimum is far better than the current one.

The second part is that we are being pretty nitpicky about the performance. In the grand scheme of things, the loss stayed very, very low throughout all of the spikes. As soon as it leaves a spike, it immediately snaps back to its original loss and stays there for a while. This is really a problem with how we posed the question, and we might as well just smooth those bumps out and call it a day.
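"Smoothing those bumps out" can be done with a simple exponential moving average over the recorded losses (the spiky series below is synthetic, just to show the idea):

```python
def ema_smooth(values, beta=0.9):
    """Exponential moving average, the usual way loss curves get smoothed."""
    smoothed, avg = [], values[0]
    for v in values:
        avg = beta * avg + (1 - beta) * v
        smoothed.append(avg)
    return smoothed

# a flat loss curve with one big spike in the middle, like our bumps
losses = [1e-6] * 50 + [1e-2] + [1e-6] * 50
smoothed = ema_smooth(losses)
# the spike's peak gets cut down by roughly 10x, and the smoothed
# curve decays back toward 1e-6 within a few dozen steps
```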

One last point I want to touch on is the phenomenon from earlier, where for the 1 and 2 epoch learning rate scans, it looks like $10^{-3}$ will blow the network up, so it seems better to use $10^{-4}$. However, when we train with learning rate $10^{-3}$ for a long time, the loss decreases very, very smoothly, with no spikes whatsoever. Furthermore, the scans follow a very bizarre pattern. For 1 epoch, it blows up a little. For 2, it blows up a lot. But then for 4 epochs, it returns to blowing up a little, and for 8 epochs, it's the sweetest spot there is. This behavior is highly mysterious; I still don't understand why it does that, and there should be more research into it.

Let's now investigate another network:

CIFAR-10 network

This is also a network taken from a previous post. For this one, I'm pretty much just copying over every function:

In [2]:
import torchvision.datasets as datasets
import torchvision
import torchvision.transforms as transforms
import time
from PIL import Image
In [3]:
means = torch.Tensor([0.4914, 0.4822, 0.4465])
stds = torch.Tensor([0.2470, 0.2435, 0.2616])
transforms = transforms.Compose([transforms.RandomHorizontalFlip(), 
                                 transforms.RandomRotation(10), 
                                 transforms.ToTensor(), 
                                 transforms.Normalize(means, stds)])
dl = torch.utils.data.DataLoader(datasets.CIFAR10("datasets/", download=True, 
                                                  train=True, transform=transforms), 
                                 batch_size=100, 
                                 shuffle=True)
dl_test = torch.utils.data.DataLoader(datasets.CIFAR10("datasets/", download=True, 
                                                       train=False, transform=transforms), 
                                      batch_size=100, 
                                      shuffle=True)
stds = stds.unsqueeze(-1).unsqueeze(-1).expand(-1, 32, 32)
means = means.unsqueeze(-1).unsqueeze(-1).expand(-1, 32, 32)
Files already downloaded and verified
Files already downloaded and verified
In [4]:
categories = ["plane", "auto", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
imgs, labels = next(iter(dl)); plt.figure(num=None, figsize=(10, 3), dpi=350)
for i in range(20):
    plt.subplot(2, 10, i+1); plt.imshow((imgs[i] * stds + means).permute(1, 2, 0))
    plt.title(categories[labels[i]]); plt.axis("off")
imgs.shape, labels.shape
Out[4]:
(torch.Size([100, 3, 32, 32]), torch.Size([100]))
In [33]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1) # 32x32 -> 32x32
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1) # 32x32 -> 16x16
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1) # 16x16 -> 16x16
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1) # 16x16 -> 8x8
        self.conv5 = nn.Conv2d(64, 128, 3, padding=1) # 8x8 -> 4x4
        self.pool = nn.MaxPool2d(2); self.relu = nn.ReLU(); self.logSoftmax = nn.LogSoftmax(1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(128 * 4 * 4, 300); self.fc2 = nn.Linear(300, 10)
        self.dropout = nn.Dropout(0.5)
        self.times, self.losses, self.accuracies = [], [], []
    def forward(self, x):
        x = self.batchnorm1(self.pool(self.relu(self.conv2(self.relu(self.conv1(x))))))
        x = self.batchnorm2(self.pool(self.relu(self.conv4(self.relu(self.conv3(x))))))
        x = self.pool(self.relu(self.conv5(x)))
        x = self.dropout(x.contiguous().view(-1, 128 * 4 * 4))
        x = self.dropout(self.relu(self.fc1(x)))
        return self.logSoftmax(self.fc2(x))
    def fit(self, epochs=30, lr=0.003):
        optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=1e-5)
        count, lossFunction = 0, nn.NLLLoss()
        lastTime, initialTime = (self.times[-1] if len(self.times) > 0 else 0), time.time()
        for epoch in range(epochs):
            for imgs, labels in dl:
                count += 1; optimizer.zero_grad(); imgs, labels = imgs.cuda(), labels.cuda()
                loss = lossFunction(self(imgs), labels); loss.backward(); optimizer.step()
                if count % 30 == 0:
                    self.eval()
                    test_imgs, test_labels = next(iter(dl_test));self.losses.append(loss.item())
                    self.accuracies.append((torch.argmax(self(test_imgs.cuda()), dim=1) == test_labels.cuda()).sum())
                    self.times.append(lastTime + time.time()-initialTime); self.train()
                    print(f"\rProgress: {np.round(100*count/(epochs*len(dl)))}%, loss: {self.losses[-1]}, accuracy: {self.accuracies[-1]}/100    ", end="")
        return torch.Tensor(self.losses), torch.Tensor(self.accuracies), torch.Tensor(self.times)
    def fitSpecial(self, cycles=1, lr=1e-3):
        optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=1e-5)
        lossFunction = nn.NLLLoss(); totalLoss = 0
        it = iter(dl)
        for cycle in range(cycles):
            imgs, labels = next(it); optimizer.zero_grad()
            loss = lossFunction(self(imgs.cuda()), labels.cuda())
            totalLoss += loss.item(); loss.backward(); optimizer.step()
        return totalLoss / (cycles * dl.batch_size)
    def getLoss(self):
        lossFunction=nn.NLLLoss(); imgs, labels = next(iter(dl))
        return lossFunction(self(imgs.cuda()), labels.cuda()).item() / dl.batch_size
net = Net().cuda(); net.load_state_dict(torch.load("models/cnn-standard.pth"))
totalAccuracy = 0; net.eval()
for test_imgs, test_labels in dl_test:
    totalAccuracy += (torch.argmax(net(test_imgs.cuda()), dim=1) == test_labels.cuda()).sum()
totalAccuracy/len(dl_test)
Out[33]:
tensor(81, device='cuda:0')
In [34]:
Net.importParams = importParams
Net.exportParams = exportParams
In [64]:
def findLr(self, cycles:int=2, begin:float=1e-6, end:float=100, steps:int=100, full:bool=False):
    lrs = torch.exp(torch.linspace(np.log(begin), np.log(end), steps))
    losses = []; initialLoss = None  # set from the first measured loss below (cheaper than self.getLoss())
    initialParams = self.exportParams()
    for lr in lrs:
        loss = self.fitSpecial(lr=lr, cycles=cycles)
        if initialLoss is None: initialLoss = loss
        losses.append(loss); self.importParams(initialParams);
        if not full:
            if loss > initialLoss * 1.1: losses.pop(); break
    return lrs[:len(losses)], torch.Tensor(losses)
Net.findLr = findLr

def LRFinder(net:Net=None, begin:float=1e-6, end:float=100, full:bool=False, rows=4, timeIt=False):
    currentTime = time.time(); net = Net().cuda() if net is None else net
    plt.figure(figsize=(10, 3*rows), dpi=350); initialParams = net.exportParams()
    for i in range(2*rows):
        plt.subplot(rows, 2, i+1); plt.title(f"Epoch {2**i}"); plt.grid(True)
        plt.xscale("log"); plt.plot(*net.findLr(int(2**i), begin=begin, end=end, full=full))
        net.importParams(initialParams); plt.xlabel("Learning rate"); plt.ylabel("Loss")
        plt.tight_layout(pad=1.0)
        if timeIt: print(f"Time taken for epoch {2**i}: {time.time() - currentTime}")
        currentTime = time.time()
    return net
In [51]:
net = LRFinder(timeIt=True)
Time taken for epoch 1: 4.740894079208374
Time taken for epoch 2: 4.104109048843384
Time taken for epoch 4: 7.598794221878052
Time taken for epoch 8: 15.492299795150757
Time taken for epoch 16: 30.622349739074707
Time taken for epoch 32: 67.27435350418091
Time taken for epoch 64: 143.91255927085876
Time taken for epoch 128: 317.8925120830536

This takes quite a while to run, and apparently 16 seems to be the number of epochs needed to make useful predictions. Let's see whether training the network shifts the distribution away or not:

In [54]:
net1 = Net().cuda()
net1.fit(epochs=64, lr=1e-3)
Progress: 100.0%, loss: 0.3174886405467987, accuracy: 87/100     
Out[54]:
(tensor([1.9282, 1.6275, 1.4798,  ..., 0.3938, 0.4670, 0.3175]),
 tensor([22., 40., 41.,  ..., 83., 88., 87.]),
 tensor([1.2976e+00, 2.6014e+00, 3.9511e+00,  ..., 1.7920e+03, 1.7934e+03,
         1.7962e+03]))
In [61]:
plt.figure(figsize=(10, 2.9), dpi=350)
plt.subplot(1, 2, 1); plt.grid(True); plt.plot(net1.times, net1.losses)
plt.subplot(1, 2, 2); plt.grid(True); plt.plot(net1.times, net1.accuracies); pass
In [65]:
LRFinder(net=net1, timeIt=True); pass
Time taken for epoch 1: 0.2124178409576416
Time taken for epoch 2: 4.547643184661865
Time taken for epoch 4: 0.6706664562225342
Time taken for epoch 8: 13.239860534667969
Time taken for epoch 16: 29.251240015029907
Time taken for epoch 32: 21.72825312614441
Time taken for epoch 64: 115.01335549354553
Time taken for epoch 128: 224.023912191391

Not really. I couldn't see any patterns here, because the loss when training normally fluctuates up and down a whole lot, so we shouldn't expect the picture to improve much after we have trained for a while. So sadly, we didn't see anything out of the ordinary for the CIFAR-10 network. Also, it seems like creating a network that controls its own learning rate is too complicated, and learning rates don't really have to change that much anyway: you can basically scan for the right value once and get state-of-the-art results right away.